{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Octo Dataloading Examples\n",
    "\n",
    "This notebook will walk you through some of the primary features of the Octo dataloader. Data is, after all, the most important part of any machine learning pipeline!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Open X-Embodiment Data\n",
    "\n",
    "The [Open X-Embodiment (OXE)](https://robotics-transformer-x.github.io/) project was a massive cross-institution data collection collaboration, the likes of which robot learning had never seen before. The resulting dataset includes 22 different robots demonstrating 527 skills and totals over 1 million trajectories. However, as we found throughout the course of the Octo project, simply loading such a diverse set of robot data is no small feat. We hope that the `octo.data` pipeline can help kickstart anyone who hopes to take advantage of the massive collection of robot demonstrations that is OXE!\n",
    "\n",
    "### Minimum working example to load a single OXE dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# minimum working example to load a single OXE dataset\n",
    "from octo.data.oxe import make_oxe_dataset_kwargs\n",
    "from octo.data.dataset import make_single_dataset\n",
    "\n",
    "dataset_kwargs = make_oxe_dataset_kwargs(\n",
    "    # see octo/data/oxe/oxe_dataset_configs.py for available datasets\n",
    "    # (this is a very small one for faster loading)\n",
    "    \"austin_buds_dataset_converted_externally_to_rlds\",\n",
    "    # can be local or on cloud storage (anything supported by TFDS)\n",
    "    # \"/path/to/base/oxe/directory\",\n",
    "    \"gs://gresearch/robotics\",\n",
    ")\n",
    "dataset = make_single_dataset(dataset_kwargs, train=True) # load the train split\n",
    "iterator = dataset.iterator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# make_single_dataset yields entire trajectories\n",
    "traj = next(iterator)\n",
    "print(\"Top-level keys: \", traj.keys())\n",
    "print(\"Observation keys: \", traj[\"observation\"].keys())\n",
    "print(\"Task keys: \", traj[\"task\"].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "traj = next(iterator)\n",
    "images = traj[\"observation\"][\"image_primary\"]\n",
    "# should be: (traj_len, window_size, height, width, channels)\n",
    "# (window_size defaults to 1)\n",
    "print(images.shape)\n",
    "# show the last 5 frames side by side (index 0 selects the single window entry)\n",
    "Image.fromarray(np.concatenate(images[-5:, 0], axis=1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# you should set these much higher in practice (as large as your memory can hold!)\n",
    "SHUFFLE_BUFFER_SIZE = 1000\n",
    "BATCH_SIZE = 64\n",
    "\n",
    "# turning a dataset of trajectories into a training-ready batched dataset\n",
    "train_dataset = (\n",
    "    dataset.flatten() # flattens trajectories into individual frames\n",
    "    .shuffle(SHUFFLE_BUFFER_SIZE) # shuffles the frames\n",
    "    .batch(BATCH_SIZE) # batches the frames\n",
    ")\n",
    "batch = next(train_dataset.iterator())\n",
    "images = batch[\"observation\"][\"image_primary\"]\n",
    "# should be: (batch_size, window_size, height, width, channels)\n",
    "print(images.shape)\n",
    "# show the first 5 frames of the batch side by side (index 0 selects the single window entry)\n",
    "Image.fromarray(np.concatenate(images[:5, 0], axis=1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading a training-ready OXE mix\n",
    "\n",
    "In reality, you're probably going to want to mix multiple datasets together, as well as use other transformations such as resizing, augmentation, windowing, etc. This section will show you how to get a proper OXE mix up and running, as well as demonstrate additional `octo.data` features for more realistic use-cases."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from octo.data.oxe import make_oxe_dataset_kwargs_and_weights\n",
    "from octo.data.dataset import make_interleaved_dataset\n",
    "\n",
    "dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(\n",
    "    # you can pass your own list of dataset names and sample weights here, but we've\n",
    "    # also provided a few named mixes for convenience. The Octo model was trained\n",
    "    # using the \"oxe_magic_soup\" mix.\n",
    "    \"rtx\",\n",
    "    # can be local or on cloud storage (anything supported by TFDS)\n",
    "    \"gs://gresearch/robotics\",\n",
    "    # let's get a wrist camera!\n",
    "    load_camera_views=(\"primary\", \"wrist\"),\n",
    ")\n",
    "\n",
    "# see `octo.data.dataset.make_dataset_from_rlds` for the meaning of these kwargs\n",
    "dataset_kwargs_list[0]"
   ]
  },
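  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `sample_weights` returned above determine how often each dataset is drawn from once the datasets are interleaved. As rough intuition only, the toy sketch below draws dataset indices in proportion to a set of weights using NumPy. The dataset names and weights are made up for illustration, and this is not how `octo.data` implements interleaving internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# hypothetical dataset names and weights, for illustration only\n",
    "toy_names = [\"dataset_a\", \"dataset_b\", \"dataset_c\"]\n",
    "toy_weights = np.array([2.0, 1.0, 1.0])\n",
    "\n",
    "# normalize the weights into sampling probabilities\n",
    "probs = toy_weights / toy_weights.sum()\n",
    "\n",
    "# draw which dataset each of 10,000 samples would come from\n",
    "rng = np.random.default_rng(0)\n",
    "draws = rng.choice(len(toy_names), size=10_000, p=probs)\n",
    "\n",
    "# the empirical fractions should be close to the normalized weights\n",
    "fractions = np.bincount(draws, minlength=len(toy_names)) / len(draws)\n",
    "for name, frac in zip(toy_names, fractions):\n",
    "    print(f\"{name}: {frac:.2f}\")"
   ]
  },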
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SHUFFLE_BUFFER_SIZE = 1000\n",
    "BATCH_SIZE = 8\n",
    "\n",
    "# each element of `dataset_kwargs_list` can be used with `make_single_dataset`, but let's\n",
    "# use the more powerful `make_interleaved_dataset` to combine them for us!\n",
    "dataset = make_interleaved_dataset(\n",
    "    dataset_kwargs_list,\n",
    "    sample_weights,\n",
    "    train=True,\n",
    "    # unlike our manual shuffling above, `make_interleaved_dataset` will shuffle\n",
    "    # the JPEG-encoded images, so you should be able to fit a much larger buffer size\n",
    "    shuffle_buffer_size=SHUFFLE_BUFFER_SIZE,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    # see `octo.data.dataset.apply_trajectory_transforms` for full documentation\n",
    "    # of these configuration options\n",
    "    traj_transform_kwargs=dict(\n",
    "        goal_relabeling_strategy=\"uniform\",  # let's get some goal images\n",
    "        window_size=2,  # let's get some history\n",
    "        action_horizon=4,  # let's get some future actions for action chunking\n",
    "        subsample_length=100,  # subsampling long trajectories improves shuffling a lot\n",
    "    ),\n",
    "    # see `octo.data.dataset.apply_frame_transforms` for full documentation\n",
    "    # of these configuration options\n",
    "    frame_transform_kwargs=dict(\n",
    "        # let's apply some basic image augmentations -- see `dlimp.transforms.augment_image`\n",
    "        # for full documentation of these configuration options\n",
    "        image_augment_kwargs=dict(\n",
    "            primary=dict(\n",
    "                augment_order=[\"random_resized_crop\", \"random_brightness\"],\n",
    "                random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),\n",
    "                random_brightness=[0.1],\n",
    "            )\n",
    "        ),\n",
    "        # providing a `resize_size` is highly recommended for a mixed dataset, otherwise\n",
    "        # datasets with different resolutions will cause errors\n",
    "        resize_size=dict(\n",
    "            primary=(256, 256),\n",
    "            wrist=(128, 128),\n",
    "        ),\n",
    "        # If parallelism options are not provided, they will default to tf.data.AUTOTUNE.\n",
    "        # However, we would highly recommend setting them manually if you run into issues\n",
    "        # with memory or dataloading speed. Frame transforms are usually the speed\n",
    "        # bottleneck (due to image decoding, augmentation, and resizing), so you can set\n",
    "        # this to a very high value if you have a lot of CPU cores. Keep in mind that more\n",
    "        # parallel calls also use more memory, though.\n",
    "        num_parallel_calls=64,\n",
    "    ),\n",
    "    # Same spiel as above about performance, although trajectory transforms and data reading\n",
    "    # are usually not the speed bottleneck. One reason to manually set these is if you want\n",
    "    # to reduce memory usage (since autotune may spawn way more threads than necessary).\n",
    "    traj_transform_threads=16,\n",
    "    traj_read_threads=16,\n",
    ")\n",
    "\n",
    "# Another performance knob to tune is the number of batches to prefetch -- again,\n",
    "# the default of tf.data.AUTOTUNE can sometimes use more memory than necessary.\n",
    "iterator = dataset.iterator(prefetch=1)"
   ]
  },
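  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To build some intuition for `window_size` and `action_horizon`, here is a toy NumPy sketch of windowing and action chunking on a fake trajectory. It only illustrates the resulting shapes. Clamping indices at the trajectory boundaries is an assumption made for this sketch; the real logic lives in `octo.data` and may handle boundaries differently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "traj_len, action_dim = 6, 2\n",
    "window_size, action_horizon = 2, 4\n",
    "\n",
    "# fake actions: the value at timestep t is simply t, so chunks are easy to read\n",
    "actions = np.tile(np.arange(traj_len, dtype=np.float32)[:, None], (1, action_dim))\n",
    "\n",
    "# history indices: for each timestep t, the window covers [t - window_size + 1, ..., t]\n",
    "t = np.arange(traj_len)[:, None]\n",
    "history_idx = t + np.arange(-window_size + 1, 1)[None, :]  # (traj_len, window_size)\n",
    "\n",
    "# chunk indices: for each window entry, that step plus the following actions\n",
    "chunk_idx = history_idx[:, :, None] + np.arange(action_horizon)[None, None, :]\n",
    "\n",
    "# clamp to valid timesteps (an assumption of this sketch, not necessarily what Octo does)\n",
    "chunk_idx = np.clip(chunk_idx, 0, traj_len - 1)\n",
    "\n",
    "chunked = actions[chunk_idx]\n",
    "# should be: (traj_len, window_size, action_horizon, action_dim)\n",
    "print(chunked.shape)\n",
    "# the action chunk attached to the current step of timestep 0\n",
    "print(chunked[0, -1, :, 0])"
   ]
  },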
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# phew, that was a lot of configuration! Let's see what we got.\n",
    "batch = next(iterator)\n",
    "print(\"Top-level keys: \", batch.keys())\n",
    "# should now have \"image_primary\" and \"image_wrist\"!\n",
    "print(\"Observation keys: \", batch[\"observation\"].keys())\n",
    "# should also have \"image_primary\" and \"image_wrist\", corresponding to future goal images\n",
    "print(\"Task keys: \", batch[\"task\"].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import numpy as np\n",
    "\n",
    "images_primary = batch[\"observation\"][\"image_primary\"]\n",
    "images_wrist = batch[\"observation\"][\"image_wrist\"]\n",
    "# should be: (batch_size, window_size (now 2), height, width, channels)\n",
    "print(images_primary.shape)\n",
    "print(images_wrist.shape)\n",
    "actions = batch[\"action\"]\n",
    "# should be: (batch_size, window_size, action_horizon, action_dim)\n",
    "print(actions.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's visualize a window of primary images\n",
    "display(Image.fromarray(np.concatenate(images_primary[0], axis=1)))\n",
    "# now a window of wrist images -- many datasets don't have wrist images,\n",
    "# so this will often be black\n",
    "display(Image.fromarray(np.concatenate(images_wrist[0], axis=1)))\n",
    "# pad_mask_dict also tells you which keys should be treated as padding\n",
    "# (e.g., if the wrist camera is black, the corresponding pad_mask_dict entry is False)\n",
    "print(batch[\"observation\"][\"pad_mask_dict\"][\"image_wrist\"][0])"
   ]
  },
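  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One common use of such a mask is to exclude padded entries from a computation. The toy sketch below uses fake NumPy arrays (not real Octo batches) to average a per-frame statistic over valid frames only. The array shapes, the mask layout, and the brightness statistic are all assumptions for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "batch_size, window_size = 4, 2\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "# fake wrist images and a fake per-frame padding mask (True = real image)\n",
    "fake_wrist = rng.integers(0, 256, size=(batch_size, window_size, 8, 8, 3), dtype=np.uint8)\n",
    "fake_mask = np.array([[True, True], [False, False], [True, True], [False, False]])\n",
    "\n",
    "# black out the padded frames, mimicking datasets without a wrist camera\n",
    "fake_wrist[~fake_mask] = 0\n",
    "\n",
    "# mean brightness per frame, averaged only over the valid frames\n",
    "per_frame = fake_wrist.mean(axis=(2, 3, 4))  # (batch_size, window_size)\n",
    "masked_mean = (per_frame * fake_mask).sum() / max(fake_mask.sum(), 1)\n",
    "print(masked_mean)"
   ]
  },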
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# let's take a look at the \"task\" dict: it should now have both goal\n",
    "# images and language instructions!\n",
    "goal_primary = batch[\"task\"][\"image_primary\"]\n",
    "goal_wrist = batch[\"task\"][\"image_wrist\"]\n",
    "language_instruction = batch[\"task\"][\"language_instruction\"]\n",
    "display(Image.fromarray(goal_primary[0]))\n",
    "display(Image.fromarray(goal_wrist[0]))\n",
    "print(language_instruction[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "octo-2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
